import d2l.torch
import numpy as np
import torch
from utils.time_tools import with_timer
from bdtime import tt
from bdtime import show_ls
import os
from utils.my_tools import device, tokenizer, output_dir
from utils.my_tools import generate
from utils.my_tools import ModelCLS, ModelGEN, ModelPPO
from utils.my_tools import test_model_gen, test_model_cls
from utils.my_tools import output_path__gen_model, output_path__cls_model, output_path__ppo_model
from utils.my_tools import conv_to_tensor
from f2_train_gen import test_gen


# Dataset sizing and sequence limits are read from the shared tokenizer object.
data_size = tokenizer.data_size
batch_size = tokenizer.batch_size
max_seq_length = tokenizer.max_seq_length
num_steps = tokenizer.num_steps

# The tokenizer's existing iterator is read and discarded; only its vocabulary
# is carried over into the freshly built training iterator.
_ = tokenizer.data_iter
texts_vocab = tokenizer.texts_vocab
data_iter, texts_vocab = tokenizer.get_data_iter(
    data_size=data_size,
    batch_size=batch_size,
    is_train=True,
    texts_vocab=texts_vocab,
    prefix=False,
)


# Policy to be trained: ModelPPO wrapped around the saved generator checkpoint.
# NOTE(review): torch.load appears to unpickle a whole model object here;
# recent PyTorch releases default to weights_only=True, which would reject
# that - confirm against the installed torch version.
model_ppo = ModelPPO(torch.load(output_path__gen_model))
# Reference copy built from the same checkpoint; PPOTrainer.step uses its
# log-probabilities for the KL penalty.
model_ppo_ref = ModelPPO(torch.load(output_path__gen_model))


# The reference model is never optimised, so its parameters are frozen.
# NOTE(review): the reference model is not switched to eval() here; if the
# underlying network has dropout, confirm that is intended.
for i in model_ppo_ref.parameters():
    i.requires_grad_(False)


def get_kl(a, b, method='kl'):
    """Return an element-wise divergence estimate between two log-prob tensors.

    Args:
        a: tensor of log-probabilities (the first operand of every formula).
        b: tensor of log-probabilities, broadcastable against ``a``.
        method: which estimator to use:
            ``'kl'``   -> ``a - b``
            ``'abs'``  -> ``|a - b|``
            ``'mse'``  -> ``0.5 * (a - b) ** 2``
            ``'full'`` -> ``torch.nn.functional.kl_div(a, b, log_target=True,
            reduction='none')``, i.e. ``exp(b) * (b - a)`` element-wise,
            because ``kl_div`` treats its first argument as the input and
            its second as the target.

    Returns:
        A tensor with the broadcast shape of ``a`` and ``b``.

    Raises:
        ValueError: if ``method`` is not one of the names above. Previously an
            unknown name silently returned ``None``.
    """
    if method == 'kl':
        return a - b

    if method == 'abs':
        return (a - b).abs()

    if method == 'mse':
        return (a - b).square() * 0.5

    if method == 'full':
        return torch.nn.functional.kl_div(a,
                                          b,
                                          log_target=True,
                                          reduction='none')

    raise ValueError(
        f"unknown method {method!r}; expected 'kl', 'abs', 'mse' or 'full'")


# Ad-hoc smoke call, apparently left over from interactive use; the result is discarded.
get_kl(torch.randn(3, 5), torch.zeros(3, 5), method='mse')


from trl.core import clip_by_value, logprobs_from_logits, masked_mean, masked_whiten


class PPOTrainer:
    """Minimal PPO trainer for the module-level ``model_ppo`` policy.

    One call to :meth:`step` takes a batch of (question, answer, reward)
    triples. It scores the batch once without gradients using both
    ``model_ppo`` and the frozen ``model_ppo_ref``, folds a per-token KL
    penalty into the reward, derives advantages and value targets, and then
    runs several optimisation passes over the batch, one sample at a time.

    The hyper-parameters are constructor keyword arguments; their defaults
    reproduce the values that were previously hard-coded in the methods.
    """

    def __init__(self, lr=1e-5, ppo_epochs=4, kl_coef=0.2, lam=0.95,
                 cliprange=0.2, cliprange_value=0.2, vf_coef=0.1,
                 ratio_threshold=10):
        """Create the Adam optimiser over ``model_ppo`` and store settings.

        Args:
            lr: Adam learning rate.
            ppo_epochs: optimisation passes over each batch in :meth:`step`.
            kl_coef: weight of the per-token KL penalty subtracted from reward.
            lam: decay applied when accumulating deltas backwards in time.
            cliprange: half-width of the policy-ratio clipping interval
                around 1.
            cliprange_value: half-width of the value clipping interval around
                the old value.
            vf_coef: weight of the value loss in the total loss.
            ratio_threshold: a sample is skipped when its masked mean
                probability ratio exceeds this.
        """
        self.ppo_epochs = ppo_epochs
        self.kl_coef = kl_coef
        self.lam = lam
        self.cliprange = cliprange
        self.cliprange_value = cliprange_value
        self.vf_coef = vf_coef
        self.ratio_threshold = ratio_threshold
        self.optimizer = torch.optim.Adam(model_ppo.parameters(), lr=lr)

    def step(self, question, answer, reward):
        """Run one PPO update on a batch.

        Args:
            question: sequence of prompt token-id tensors (``.tolist()`` and
                ``len()`` are used on each element).
            answer: sequence of generated token-id tensors, aligned with
                ``question``.
            reward: per-sample rewards, indexable by sample position.

        Samples are skipped (no backward pass) when their answer mask is
        empty, when :meth:`get_loss` returns ``None``, or when the loss is
        not finite.
        """
        with torch.no_grad():
            # After encoding, only the lengths of question/answer are needed.
            lens_q = [len(q) for q in question]
            lens_a = [len(a) for a in answer]

            # Encode: concatenate each prompt with its answer and pad the batch.
            token = [q.tolist() + a.tolist() for q, a in zip(question, answer)]
            input_ids, attention_mask = tokenizer.batch_pad(token=token)
            input_ids = torch.LongTensor(input_ids).to(device)
            attention_mask = torch.LongTensor(attention_mask).to(device)

            # Log-prob of every answer token and the value of every position
            # under the current policy.
            prob_log, value, mask = self.batched_forward_pass(
                model_ppo, input_ids, attention_mask, lens_q, lens_a)

            # Same log-probs under the frozen reference model, for the KL term.
            prob_log_ref, _, _ = self.batched_forward_pass(
                model_ppo_ref, input_ids, attention_mask, lens_q, lens_a)

            # Per-token reward: KL penalty everywhere, plus the scalar reward
            # on the last answer token.
            reward = self.compute_rewards(reward, prob_log, prob_log_ref, mask)

            # Advantages (delta) and value targets used by the loss.
            value, delta, target = self.compute_advantages(value, reward, mask)

        # Several optimisation passes over the same batch.
        for _ in range(self.ppo_epochs):
            # One sample at a time.
            for i in range(len(input_ids)):
                # With no answer positions, masked_mean would divide by a zero
                # mask sum and the resulting NaN loss would poison the weights.
                if mask[i].sum().item() == 0:
                    continue

                # Recompute log-probs and values with gradients enabled.
                prob_log_new, value_new, _ = self.batched_forward_pass(
                    model_ppo, input_ids[i].unsqueeze(0),
                    attention_mask[i].unsqueeze(0), [lens_q[i]], [lens_a[i]])

                # Policy loss from the old/new probability ratio, plus the
                # value loss from the gap between target and value.
                loss = self.get_loss(prob_log[i].unsqueeze(0),
                                     value[i].unsqueeze(0), prob_log_new,
                                     value_new, mask[i].unsqueeze(0),
                                     delta[i].unsqueeze(0),
                                     target[i].unsqueeze(0))

                # Explicit None test: a loss that is exactly 0.0 must still be
                # back-propagated, and a NaN/inf loss must not be.
                if loss is None or not torch.isfinite(loss).item():
                    continue

                loss.backward()
                # Gradient clipping was disabled in the original and stays so.
                self.optimizer.step()
                self.optimizer.zero_grad()

    def batched_forward_pass(self, model, input_ids, attention_mask, lens_q, lens_a):
        """Score a padded batch with ``model``.

        Args:
            model: callable returning ``(logits, value)`` for
                ``input_ids``/``attention_mask`` keyword arguments.
            input_ids: padded token ids, indexed as ``[batch, position]``.
            attention_mask: same layout as ``input_ids``.
            lens_q: prompt length for each row.
            lens_a: answer length for each row.

        Returns:
            ``(prob_log, value, mask)``, each with the last position dropped:
            the log-prob of every actual next token, the value output, and a
            mask that is 1 only where the predicted token is a non-padding
            answer token.
        """
        logits, value = model(input_ids=input_ids,
                              attention_mask=attention_mask)

        # Log-probability of each actual next token.
        prob_log = logprobs_from_logits(logits[:, :-1], input_ids[:, 1:])

        # Shift the attention mask left by one so that position t describes
        # the token predicted at t, then keep only the answer span.
        mask = torch.zeros_like(attention_mask)
        mask[:, :-1] = attention_mask[:, 1:]
        for row, (len_q, len_a) in enumerate(zip(lens_q, lens_a)):
            start = len_q - 1
            end = start + len_a
            mask[row, :start] = 0
            mask[row, end:] = 0

        # The prediction made at the last position has no target; drop it.
        return prob_log, value[:, :-1], mask[:, :-1]

    def compute_rewards(self, reward, prob_log, prob_log_ref, mask):
        """Build per-token rewards: a KL penalty plus the scalar reward.

        For each sample the per-token value is
        ``-kl_coef * get_kl(prob_log, prob_log_ref)``; the sample's scalar
        reward is added at the last masked position (position 0 when the mask
        is entirely zero).

        Returns:
            Stacked per-token reward tensor, one row per sample.
        """
        reward_kl = []

        for i in range(len(reward)):
            # KL estimate between the current and the reference policy.
            kl = get_kl(prob_log[i], prob_log_ref[i]) * -self.kl_coef

            # Add the reward on the last answer token.
            if (mask[i] == 0).all():
                idx = 0
            else:
                idx = mask[i].nonzero()[-1].item()
            kl[idx] += reward[i]

            reward_kl.append(kl)

        return torch.stack(reward_kl)

    def compute_advantages(self, value, reward_kl, mask):
        """Compute whitened advantages and value targets.

        Walks the sequence backwards accumulating
        ``d_t = r_t + V_{t+1} - V_t`` with decay ``lam`` (the discount factor
        is 1 and therefore omitted).

        Returns:
            ``(value, delta, target)``: the masked values, the advantages
            after ``masked_whiten``, and ``target = raw_delta + value``.
        """
        value = value * mask
        reward_kl = reward_kl * mask

        delta = []
        lens = reward_kl.shape[1]

        # Iterate from the last time step to the first.
        for i in reversed(range(lens)):
            # Value of the next step; 0 beyond the end of the sequence.
            value_next = 0
            if i < lens - 1:
                value_next = value[:, i + 1]

            # Ideally value == next value + reward; the gap is the delta.
            d = reward_kl[:, i] + value_next - value[:, i]

            # Previously accumulated delta (0 on the first iteration).
            last_d = 0
            if delta:
                last_d = delta[-1]

            # Propagate deltas backwards; lam weights how strongly later
            # steps influence earlier ones.
            delta.append(d + self.lam * last_d)

        # Restore chronological order and the [batch, time] layout.
        delta = torch.stack(delta[::-1]).transpose(0, 1)

        # The target is the estimate of the ideal value.
        target = delta + value
        delta = masked_whiten(delta, mask)

        return value, delta, target

    def get_loss(self, prob_log, value, prob_log_new, value_new, mask, delta, target):
        """Return the clipped PPO loss for one sample, or ``None`` to skip it.

        The loss is ``policy_loss + vf_coef * value_loss``, both averaged over
        masked positions. ``None`` is returned when the masked mean of the
        new/old probability ratio exceeds ``ratio_threshold``.
        """
        # Difference of log-probs, exponentiated: the new/old probability ratio.
        ratio = (prob_log_new - prob_log).exp()

        # A very large ratio suggests the update is oscillating; skip it.
        if masked_mean(ratio, mask).item() > self.ratio_threshold:
            return None

        # Value loss: squared error against the target, with and without
        # clipping the new value to a band around the old value.
        loss_vf1 = (value_new - target)**2
        loss_vf2 = clip_by_value(value_new, value - self.cliprange_value,
                                 value + self.cliprange_value)
        loss_vf2 = (loss_vf2 - target)**2
        # Take the larger (more pessimistic) of the two.
        loss_vf = 0.5 * masked_mean(torch.max(loss_vf1, loss_vf2), mask)

        # Policy loss: unclipped and clipped surrogate, again taking the
        # larger of the two.
        loss_surr1 = -delta * ratio
        loss_surr2 = -delta * ratio.clamp(1.0 - self.cliprange,
                                          1.0 + self.cliprange)
        loss_surr = masked_mean(torch.max(loss_surr1, loss_surr2), mask)

        return loss_surr + self.vf_coef * loss_vf


# Single trainer instance; its step() is called once per iteration of the
# training loop further down.
trainer = PPOTrainer()

# Bare expression: has no effect when run as a script (it only echoes the
# object in a notebook/REPL).
trainer


# Load the classifier that get_reward uses to score answers and move it to
# the shared device.
model_cls = torch.load(output_path__cls_model)
model_cls.to(device)

# The classifier is only used for scoring (under torch.no_grad in
# get_reward), so its parameters are frozen.
for i in model_cls.parameters():
    i.requires_grad_(False)


from utils.my_tools import TestModel

# Fetch one batch (batch_size=16) of labels, questions, attention masks and
# real answers; it is used below for a pre-training preview.
label, question, attention_mask, real_answer = TestModel.get_question(prefix=True, batch_size=16, ret_real_answer=True)

# The commented-out lines below are leftover interactive inspection snippets.
# [len(q) for q in question]
# tokenizer.encoder['=']
# show_ls(label[:3])
# show_ls(question[:3])
# q = question[0]
# NOTE(review): `a` looks like an inspection leftover too; nothing at module
# level reads it before the training loop rebinds it.
a = real_answer[0]
# tokenizer.decode(question, '')
# tokenizer.decode(real_answer, '')
# tokenizer.decode(a, '')
# show_ls([tokenizer.decode(q.tolist() + a.tolist(), '') for q, a in list(zip(question, real_answer))[:10]])


# If every question had the same length, this could be done as one batched call.
def get_answer(question):
    """Generate one continuation per prompt with the current PPO policy.

    Each prompt is passed to ``generate`` on its own, as a batch of one,
    using ``model_ppo.model_gen``.

    Args:
        question: iterable of prompt token-id tensors.

    Returns:
        A list with one tensor per prompt: row 0 of the generated output with
        the first ``len(prompt)`` positions sliced off, so only the generated
        part remains.
    """
    # The unreachable ``if 0:`` debugging branches that used to live here
    # were removed; they never executed.
    return [
        generate(model_ppo.model_gen, prompt.unsqueeze(0))[0, len(prompt):]
        for prompt in question
    ]


# Generate answers for the preview batch before any PPO update has run.
answer = get_answer(question)

# len(answer)
# answer[:10]
# show_ls(tokenizer.decode(answer[:5], concat_symbol=""))


def get_reward(question, answer, label):
    """Score each question+answer pair with the frozen classifier.

    Every prompt is concatenated with its answer, the batch is padded by the
    tokenizer and fed to ``model_cls`` without gradients. The returned reward
    for a sample is the classifier output in the column selected by that
    sample's ``label``.

    Args:
        question: sequence of prompt token-id tensors.
        answer: sequence of answer token-id tensors, aligned with ``question``.
        label: tensor of column indices, one per sample.

    Returns:
        1-D tensor with one reward per sample.
    """
    joined = []
    for q_ids, a_ids in zip(question, answer):
        joined.append(q_ids.tolist() + a_ids.tolist())

    padded_ids, padded_mask = tokenizer.batch_pad(token=joined)
    ids = torch.LongTensor(padded_ids).to(device)
    attn = torch.LongTensor(padded_mask).to(device)

    with torch.no_grad():
        scores = model_cls(input_ids=ids, attention_mask=attn)

    # Pick, for every row, the score in the column named by its label.
    picked = scores.gather(1, label.reshape(-1, 1))
    return picked.squeeze(1)


# Score the preview batch with the classifier.
reward = get_reward(question, answer, label)

# Upper bound on the number of training iterations.
max_epoch = 2000
from tqdm import tqdm

# tq_i = tqdm(total=max_epoch)

# Show the first five samples as "question   real answer   generated answer"
# (three space tokens between the parts), followed by their rewards.
space = tokenizer.encoder[' ']
show_ls([tokenizer.decode(q.tolist() + [space] * 3 + r_a.tolist() + [space] * 3 + a.tolist(), '') for q, r_a, a in list(zip(question, real_answer, answer))[:5]])
print('--- rewards:', reward[:5])

# Announce the stop hotkey checked in the training loop, then pause before
# training starts.
print("\n======== 快捷键[ctrl + t]结束训练并保存模型")
tt.sleep(5)

# Main PPO loop: each iteration samples a fresh batch, generates answers with
# the current policy, scores them with the classifier and runs one
# trainer.step(). Every 10th iteration a short progress report is printed.
with with_timer(f"模型训练", tt) as wt:
    for epoch in range(max_epoch):
        # tq_i.update(1)
        # Leave the loop early when the user presses the stop hotkey; the code
        # after the loop still saves the model.
        if tt.stop('ctrl + t'):
            print('***** break by user! current epoch:', epoch, '******')
            break

        # label, question, attention_mask = get_question()
        label, question, attention_mask, real_answer = TestModel.get_question(prefix=True, batch_size=16, ret_real_answer=True)

        answer = get_answer(question)
        reward = get_reward(question, answer, label)

        trainer.step(question, answer, reward)

        if epoch % 10 == 0:
            mean_reward = round(reward.mean().item(), 3)
            # print(epoch, reward.mean().item())

            # Print at most three samples: zip with range(3) caps the count.
            for _, q, a, r, _real_answer in zip(range(3), question, answer, reward, real_answer):
                q = tokenizer.decode(q.tolist(), '')
                a = tokenizer.decode(a.tolist(), '')
                r = round(r.item(), 3)
                print('q, a, r --- ', q, a, r, '--- real:', tokenizer.decode(_real_answer.tolist(), ''))

            # space = tokenizer.encoder[' ']
            # show_ls(
            #     [tokenizer.decode(q.tolist() + [space] * 3 + r_a.tolist() + [space] * 3 + a.tolist(), '') for q, r_a, a
            #      in list(zip(question, real_answer, answer))[:5]])
            # print('--- rewards:', reward[:5])

            # my_questions = ['bar: 2011年03月13日=', '请问: 1981年05月21日=']
            # test_gen(model_gen=model_ppo.model_gen, test_times=3, my_questions=my_questions)
            # test_gen(model_gen=model_ppo.model_gen, test_times=3, prefix=True, batch_size=1)

            # Fixed probe prompts, one per prefix.
            my_questions = [
                'dot: 2222年02月22日=',
                'slash: 1999年09月29日=',
                'abbr: 1981年05月21日=',
            ]
            # my_answers = [
            #     '2022.02.22',
            #     '1999/09/29',
            #     '21/05/1981',
            # ]
            # NOTE(review): unlike the commented-out calls above, this call does
            # not pass model_gen=model_ppo.model_gen, so it may be probing
            # test_gen's default model rather than the policy being trained -
            # confirm against f2_train_gen.test_gen.
            test_gen(test_times=1, my_questions=my_questions)
            wt.show(f'~~~~~~~~~ epoch: {epoch} / {max_epoch}, mean_reward: {mean_reward}', reset_cost=True)


# Move the trained policy to the CPU before serialising it.
model_ppo.to('cpu')

# Warn (with a 10-unit countdown) before overwriting an existing PPO
# checkpoint. The check and the message now refer to the file that is
# actually written, output_path__ppo_model; previously they referred to the
# generator checkpoint, which is not the file being overwritten.
if os.path.exists(output_path__ppo_model):
    tt.tqdm_sleep(f'********** Warning: model_ppo[{output_path__ppo_model}]已存在! 将进行覆盖操作!', T=10)
torch.save(model_ppo, output_path__ppo_model)









